import scipy
import numpy as np 
from numpy.lib.scimath import sqrt as csqrt

def empirical_Volterra(n, max_iter, gamma, eiglist, zeta, Delta, R, R_tilde):
    """Compute the solution psi_0(t) of the empirical Volterra equation.

    psi_0 is obtained by forward substitution of

        psi_0(t) = R/2 * h_1(t) + R_tilde/2 * h_0(t)
                   + gamma**2 * zeta * (1 - zeta)
                     * sum_{s=0}^{t-1} H_2(t-1-s) * psi_0(s),

    where h_0, h_1 and H_2 are (1/n)-weighted sums over ``eiglist`` of
    closed-form expressions in lambda_2, lambda_3, the roots of
    y**2 - (omega**2 - 2*Delta) * y + Delta**2 with
    omega = 1 - zeta*gamma*x + Delta for an eigenvalue x.

    Parameters
    ----------
    n : int
        Number of data points.
    max_iter : int
        Number of values computed (t = 0, ..., max_iter - 1).
    gamma : float
        Learning rate (step size).
    eiglist : array_like
        Eigenvalues of the covariance matrix AA^T.
    zeta : float
        batch_size / n.
    Delta : float
        Momentum parameter.
    R : float
        Normalization constant for the signal (see Problem Setting in paper).
    R_tilde : float
        Normalization constant for the noise (see Problem Setting in paper).

    Returns
    -------
    numpy.ndarray
        Real float array of length ``max_iter`` holding
        psi_0(0), ..., psi_0(max_iter - 1).

    Notes
    -----
    h_0 and h_1 are evaluated with Delta floored at 1e-8. For such Delta the
    closed forms are algebraically equal to
    h_k(t) = (1/n) * sum_x x**k * e_t(x)**2 and
    H_2(t) = (1/n) * sum_x x**2 * u_{t+1}(x)**2,
    where e and u both obey y_{t+1} = omega*y_t - Delta*y_{t-1} with
    e_{-1} = e_0 = 1 and u_0 = 0, u_1 = 1.
    The expressions divide by omega**2 - 4*Delta, so an eigenvalue with
    omega**2 == 4*Delta exactly (a repeated root) yields inf/nan.
    """
    eigs = np.asarray(eiglist)

    def omega(delta):
        # 1 - zeta*gamma*x + delta for every eigenvalue x.
        return 1 - zeta * gamma * eigs + delta

    def lambda_pair(omega_sqd, delta):
        # (lambda_2, lambda_3): roots of y**2 - (omega_sqd - 2*delta)*y + delta**2
        # via the principal complex square root; lambda_2 takes the "+" branch.
        disc = csqrt(omega_sqd * (omega_sqd - 4 * delta))
        centre = -2 * delta + omega_sqd
        return (centre + disc) / 2, (centre - disc) / 2

    omega_sqd = (omega(Delta) ** 2).astype(np.complex128)

    # ---- loop-invariant pieces of h_0 and h_1 ---------------------------------
    # Delta is floored at 1e-8 here: with Delta == 0, lambda_3 is 0 and
    # kappa_3 = lambda_3*omega / (lambda_3 + Delta) would be 0/0.
    # complex128 (not complex64) so h_k uses the same Delta, at the same
    # precision, as omega_sqd and H_2; a float32-rounded Delta gives large
    # relative errors for eigenvalues with omega**2 close to 4*Delta.
    delta_h = np.maximum(Delta, 1e-8).astype(np.complex128)
    lam2_h, lam3_h = lambda_pair(omega_sqd, delta_h)
    omega_h = omega(delta_h)
    root2, root3 = lambda_pair(omega_h ** 2, delta_h)
    kappa2 = (root2 * omega_h) / (root2 + delta_h)
    kappa3 = (root3 * omega_h) / (root3 + delta_h)
    coef_delta = -delta_h * gamma * zeta * eigs
    coef2 = 0.5 * (kappa2 - delta_h) ** 2
    coef3 = 0.5 * (kappa3 - delta_h) ** 2
    denom_h = omega_sqd - 4 * delta_h
    weight0 = (1 / n) * (2 / denom_h)
    weight1 = (1 / n) * (2 * eigs / denom_h)

    # ---- loop-invariant pieces of H_2 (unfloored Delta) -----------------------
    lam2, lam3 = lambda_pair(omega_sqd, Delta)
    weight2 = (1 / n) * (2 * eigs ** 2) / (omega_sqd - 4 * Delta)

    h0 = np.zeros(max_iter, dtype=np.complex128)
    h1 = np.zeros(max_iter, dtype=np.complex128)
    H_2 = np.zeros(max_iter, dtype=np.complex128)
    for t in range(max_iter):
        # h_0 and h_1 share the same bracket; only the weights differ.
        bracket = (
            coef_delta * delta_h ** t
            + coef2 * lam2_h ** t
            + coef3 * lam3_h ** t
        )
        h0[t] = np.dot(weight0, bracket)
        h1[t] = np.dot(weight1, bracket)
        bracket2 = (
            -Delta ** (t + 1)
            + 0.5 * lam2 ** (t + 1)
            + 0.5 * lam3 ** (t + 1)
        )
        H_2[t] = np.dot(weight2, bracket2)

    # ---- forward substitution of the discrete Volterra equation --------------
    conv_scale = gamma ** 2 * zeta * (1 - zeta)
    psi0 = np.zeros(max_iter)
    for t in range(max_iter):
        value = 0.5 * R * h1[t] + 0.5 * R_tilde * h0[t]
        if t > 0:
            # sum_{s=0}^{t-1} H_2(t-1-s) * psi_0(s)
            value = value + conv_scale * np.dot(H_2[t - 1::-1], psi0[:t])
        # The terms are mathematically real, so discarding the (round-off)
        # imaginary part has no effect; casting also silences numpy's warning.
        psi0[t] = np.real(value)
    return psi0
